from __future__ import annotations

import argparse
import json
from dataclasses import dataclass
from pathlib import Path

import numpy as np
import pandas as pd
import statsmodels.api as sm
from matplotlib import pyplot as plt

# Headline interpretation that main() writes verbatim into tables/main_specs_meta.json
# and appends to the coefficient-plot title.
# NOTE(review): this text is fixed in advance and is emitted regardless of the sign or
# significance of the estimated coefficients — confirm it matches the results before publishing.
TAKE_HOME_MESSAGE = (
    "More media/information freedom is associated with greater transparency during COVID "
    "(smaller excess-vs-reported mortality gaps) and lower true health losses, consistent "
    "with information environments functioning as public health capacity—especially for vulnerable groups."
)


@dataclass(frozen=True)
class Window:
    """Named, inclusive calendar-month range over which the monthly panel is aggregated.

    Instances are immutable (``frozen=True``).
    """

    name: str  # label copied into the ``window`` column of aggregated/regression outputs
    start: str  # inclusive YYYY-MM
    end: str  # inclusive YYYY-MM


def month_str_to_ts(month_id: str) -> pd.Timestamp:
    """Convert a ``YYYY-MM`` month identifier to the Timestamp at the start of that month."""
    monthly_period = pd.Period(month_id, freq="M")
    return monthly_period.to_timestamp()


def filter_panel(panel: pd.DataFrame, min_population: int) -> pd.DataFrame:
    """Return a cleaned copy of the monthly panel, ready for window aggregation.

    Steps:
      * parse ``month`` to datetimes and add a ``month_id`` (``YYYY-MM``) string column;
      * keep rows with non-missing ``media_freedom`` and ``population`` and with
        ``population >= min_population``;
      * if ``gdp_per_capita`` is present, add ``log_gdp_pc``. Non-positive GDP values are
        mapped to NaN, not ``-inf``, so they drop out of the regressions as missing.

    The input frame is not modified.
    """
    d = panel.copy()
    d["month"] = pd.to_datetime(d["month"])
    d["month_id"] = d["month"].dt.to_period("M").astype(str)

    keep = (
        d["media_freedom"].notna()
        & d["population"].notna()
        & (d["population"] >= min_population)
    )
    d = d[keep].copy()

    if "gdp_per_capita" in d.columns:
        gdp = d["gdp_per_capita"]
        # log is undefined for gdp <= 0; mask to NaN so -inf never reaches the regressions
        # (dropna() there does not remove infinities).
        d["log_gdp_pc"] = np.log(gdp.where(gdp > 0))

    return d


def aggregate_country_window(
    d: pd.DataFrame,
    window: Window,
    outcome: str,
    outcome_mode: str,
    min_months_nonmissing: int,
) -> pd.DataFrame:
    """Collapse the monthly panel to one row per country (``iso_code``) for a time window.

    Args:
        d: monthly panel with at least ``iso_code``, ``month``, ``month_id``, ``outcome``,
           ``reported_deaths_pm`` and ``reported_cases_pm`` columns.
        window: month range; both ``start`` and ``end`` are inclusive at month granularity,
            so any date inside the end month (e.g. ``2021-12-31``) is included.
        outcome: column aggregated into ``y``.
        outcome_mode: ``"sum"`` -> ``y`` is the sum of ``outcome`` over the window, and
            countries need at least ``min_months_nonmissing`` non-missing months.
            ``"cum_delta"`` -> ``y`` is (last - first) non-missing value of a cumulative
            ``outcome`` within the window, and countries need at least 2 non-missing months.
        min_months_nonmissing: coverage threshold used by the ``"sum"`` mode.

    Returns:
        One row per retained country, with these columns:
          * window sums of ``reported_deaths_pm`` and ``reported_cases_pm``;
          * ``months``, the number of distinct ``month_id`` values;
          * the within-window max of the available static covariates;
          * ``y``;
          * ``log_gdp_pc``, NaN for non-positive GDP;
          * ``window``/``window_start``/``window_end``/``outcome_mode`` metadata.

    Raises:
        ValueError: unknown ``outcome_mode``.
    """
    if outcome_mode not in ("sum", "cum_delta"):
        raise ValueError(f"Unknown outcome_mode: {outcome_mode}")

    # Compare at monthly-period granularity so the end month is fully inclusive regardless
    # of which day-of-month the panel's ``month`` timestamps carry.
    months = pd.to_datetime(d["month"]).dt.to_period("M")
    start_p = pd.Period(window.start, freq="M")
    end_p = pd.Period(window.end, freq="M")
    dd = d[(months >= start_p) & (months <= end_p)].copy()

    # Coverage filter: "sum" uses the configurable threshold; "cum_delta" needs 2 points.
    min_obs = min_months_nonmissing if outcome_mode == "sum" else 2
    n_nonmissing = dd.groupby("iso_code")[outcome].count().rename("n_months_nonmissing")
    dd = dd.merge(n_nonmissing.reset_index(), on="iso_code", how="left")
    dd = dd[dd["n_months_nonmissing"] >= min_obs].copy()

    cols_max = [
        "population",
        "gdp_per_capita",
        "median_age",
        "hospital_beds_per_thousand",
        "diabetes_prevalence",
        "human_development_index",
        "extreme_poverty",
        "continent",
        "location",
        "media_freedom",
        "freedom_expression",
        "alternative_info",
        "baseline_year",
    ]
    # Optional baseline controls (if present)
    cols_max += [c for c in dd.columns if c.startswith("wgi_")]
    cols_max += [c for c in dd.columns if c.startswith("death_reg_")]
    cols_max = [c for c in cols_max if c in dd.columns]

    if outcome_mode == "sum":
        y_series = dd.groupby("iso_code")[outcome].sum(min_count=min_months_nonmissing).rename("y")
    else:
        # Delta of the cumulative series within the window: last non-null minus first
        # non-null, ordered by month.
        ordered = dd.sort_values("month", kind="stable").dropna(subset=[outcome])
        grouped = ordered.groupby("iso_code")[outcome]
        delta = (grouped.last() - grouped.first()).astype(float)
        y_series = delta.where(grouped.count() >= 2).rename("y")

    agg = (
        dd.groupby(["iso_code"], as_index=False)
        .agg(
            reported_deaths_pm=("reported_deaths_pm", "sum"),
            reported_cases_pm=("reported_cases_pm", "sum"),
            months=("month_id", "nunique"),
            **{c: (c, "max") for c in cols_max},
        )
        .merge(y_series.reset_index(), on="iso_code", how="left")
    )
    if "gdp_per_capita" in agg.columns:
        gdp = agg["gdp_per_capita"]
        # Non-positive GDP -> NaN (not -inf) so the row is treated as missing downstream.
        agg["log_gdp_pc"] = np.log(gdp.where(gdp > 0))
    agg["window"] = window.name
    agg["window_start"] = window.start
    agg["window_end"] = window.end
    agg["outcome_mode"] = outcome_mode
    return agg


def run_regression(
    df: pd.DataFrame,
    y_col: str,
    x_col: str,
    controls: list[str],
    weight_col: str | None,
    robust: str,
) -> dict:
    """Regress ``y_col`` on a constant, ``x_col`` and the available ``controls``.

    Controls missing from ``df`` are silently skipped, and duplicates (including ``x_col``
    itself) are dropped. Rows with NaN or +/-inf in any used column are excluded.

    Args:
        df: cross-section to fit.
        y_col: dependent variable.
        x_col: regressor of interest; its estimate is reported.
        controls: candidate control columns.
        weight_col: if given, fit WLS with these weights; otherwise OLS.
        robust: ``"HC1"`` (case-insensitive) for HC1 robust standard errors, anything
            else for the default non-robust covariance.

    Returns:
        ``{"n", "coef", "se", "p", "r2"}`` for ``x_col``. Returns ``{}`` when there are not
        more usable rows than model parameters, because the fit would be saturated and
        the HC1 scale undefined.

    Raises:
        KeyError: ``weight_col`` is given but is not a column of ``df``.
    """
    if weight_col and weight_col not in df.columns:
        # Fail loudly rather than mislabel an unweighted fit as WLS.
        raise KeyError(f"weight column {weight_col!r} not found in data")

    use_controls = list(dict.fromkeys(c for c in controls if c in df.columns and c != x_col))
    cols = list(dict.fromkeys([y_col, x_col, *use_controls] + ([weight_col] if weight_col else [])))
    # dropna() does not remove infinities (e.g. log of 0); treat them as missing too.
    d = df[cols].replace([np.inf, -np.inf], np.nan).dropna()

    n_params = 2 + len(use_controls)  # constant + x_col + controls
    if len(d) <= n_params:
        return {}

    X = sm.add_constant(d[[x_col, *use_controls]].astype(float), has_constant="add")
    y = d[y_col].astype(float)

    if weight_col:
        model = sm.WLS(y, X, weights=d[weight_col].astype(float))
    else:
        model = sm.OLS(y, X)

    fit = model.fit(cov_type="HC1") if robust.lower() == "hc1" else model.fit()

    return {
        "n": int(fit.nobs),
        "coef": float(fit.params[x_col]),
        "se": float(fit.bse[x_col]),
        "p": float(fit.pvalues[x_col]),
        "r2": float(getattr(fit, "rsquared", np.nan)),
    }


def coefplot(results: pd.DataFrame, out_path: Path, title: str) -> None:
    """Save a horizontal coefficient plot with 95% normal-approximation intervals.

    Each row of ``results`` (needs ``coef``, ``se``, ``outcome``, ``outcome_mode``, ``window``,
    ``spec``) becomes one point with a ``±1.96·se`` bar. Rows are sorted by
    outcome/outcome_mode/window/spec. Does nothing when ``results`` is empty. Parent
    directories of ``out_path`` are created as needed.
    """
    if results.empty:
        return

    d = results.sort_values(["outcome", "outcome_mode", "window", "spec"])
    half_width = 1.96 * d["se"]
    labels = (d["outcome"] + " (" + d["outcome_mode"] + ") | " + d["window"] + " | " + d["spec"]).tolist()
    y_pos = np.arange(len(d))

    fig, ax = plt.subplots(figsize=(12, max(6, 0.35 * len(d) + 2)))
    try:
        ax.errorbar(d["coef"], y_pos, xerr=half_width, fmt="o", color="black", ecolor="black", capsize=2)
        ax.axvline(0, color="gray", linewidth=1)
        ax.set_yticks(y_pos)
        ax.set_yticklabels(labels, fontsize=8)
        ax.set_title(title)
        ax.set_xlabel("Coefficient on media freedom (baseline)")
        fig.tight_layout()
        out_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(out_path, dpi=200)
    finally:
        # Always release the figure, even if rendering or saving fails.
        plt.close(fig)


def main() -> None:
    """Run the cross-sectional media-freedom analysis and write tables and a coefficient plot.

    Reads the monthly panel (``--panel``) and filters it with ``filter_panel``. It then
    collapses the panel to country cross-sections per window and outcome definition
    (``aggregate_country_window``) and fits the specifications with ``run_regression``.

    Files written under ``--out-dir``:
      * ``tables/main_specs.csv``: every window x outcome x spec estimate on ``media_freedom``;
      * ``tables/sample_summary.csv``: country count and mean/median ``y`` per window x outcome;
      * ``tables/robustness_alt_x.csv``: alternative information-environment regressors;
      * ``tables/robustness_leave_one_continent.csv``: leave-one-continent-out estimates;
      * ``tables/robustness_data_quality.csv``: adding WGI / death-registration controls;
      * ``tables/robustness_2023_completeness.csv``: the 2020-2023 window restricted by
        2023 outcome coverage;
      * ``tables/main_specs_meta.json`` and ``figures/coefplot_media_freedom_main_specs.png``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--panel", type=Path, default=Path("outputs/data/panel_merged.parquet"))
    parser.add_argument("--out-dir", type=Path, default=Path("outputs"))
    parser.add_argument("--min-population", type=int, default=1_000_000)
    parser.add_argument("--min-months-nonmissing", type=int, default=12)
    args = parser.parse_args()

    out_dir = args.out_dir
    tables_dir = out_dir / "tables"
    figures_dir = out_dir / "figures"
    tables_dir.mkdir(parents=True, exist_ok=True)
    figures_dir.mkdir(parents=True, exist_ok=True)

    panel = pd.read_parquet(args.panel)
    d = filter_panel(panel, min_population=args.min_population)

    windows = [
        Window("2020-2021", "2020-02", "2021-12"),
        Window("2020-2022", "2020-02", "2022-12"),
        Window("2020-2023", "2020-02", "2023-12"),
    ]

    outcomes = [
        ("gap_pm_excess_minus_reported", "sum", "Gap pm (sum excess - sum reported)"),
        ("gap_pm_clip0", "sum", "Gap pm (clip excess≥0, then sum - sum reported)"),
        ("excess_deaths_pm", "sum", "Excess deaths pm (sum)"),
        ("excess_deaths_pm_clip0", "sum", "Excess deaths pm (sum, clip≥0)"),
        ("excess_cum_pm", "cum_delta", "Excess deaths pm (delta cumulative)"),
        ("excess_cum_abs", "cum_delta", "Excess deaths (delta cumulative abs)"),
    ]

    base_controls = [
        c
        for c in (
            "log_gdp_pc",
            "median_age",
            "hospital_beds_per_thousand",
            "diabetes_prevalence",
            "human_development_index",
        )
        if c in d.columns
    ]

    # (name, weight column, covariance type, controls)
    specs = [
        ("OLS_min_HC1", None, "HC1", []),
        ("OLS_full_HC1", None, "HC1", base_controls),
        ("WLS_pop_full_HC1", "population", "HC1", base_controls),
    ]

    rows: list[dict] = []
    sample_rows: list[dict] = []
    alt_x_rows: list[dict] = []
    loo_rows: list[dict] = []
    dq_rows: list[dict] = []
    completeness_rows: list[dict] = []

    # Main grid: every window x outcome definition x specification.
    for win in windows:
        for out_col, out_mode, out_label in outcomes:
            agg = aggregate_country_window(
                d, win, out_col, out_mode, min_months_nonmissing=args.min_months_nonmissing
            )

            sample_rows.append(
                {
                    "window": win.name,
                    "outcome": out_col,
                    "outcome_mode": out_mode,
                    "countries": int(agg["iso_code"].nunique()),
                    "mean_y": float(agg["y"].mean()) if not agg.empty else np.nan,
                    "median_y": float(agg["y"].median()) if not agg.empty else np.nan,
                }
            )

            for spec_name, weight_col, robust, controls in specs:
                res = run_regression(
                    agg,
                    y_col="y",
                    x_col="media_freedom",
                    controls=controls,
                    weight_col=weight_col,
                    robust=robust,
                )
                if not res:
                    continue
                rows.append(
                    {
                        "window": win.name,
                        "window_start": win.start,
                        "window_end": win.end,
                        "outcome": out_col,
                        "outcome_mode": out_mode,
                        "outcome_label": out_label,
                        "spec": spec_name,
                        "min_population": args.min_population,
                        "min_months_nonmissing": args.min_months_nonmissing,
                        "controls": "+".join(controls),
                        "n": res["n"],
                        "coef": res["coef"],
                        "se": res["se"],
                        "p": res["p"],
                        "r2": res["r2"],
                    }
                )

    # Focused robustness: alternative "X" measures of information environment
    alt_x_vars = [
        ("media_freedom", "v2x_freexp_altinf"),
        ("freedom_expression", "v2x_freexp"),
        ("alternative_info", "v2xme_altinf"),
    ]
    focus_window = Window("2020-2022", "2020-02", "2022-12")
    focus_outcome = ("gap_pm_excess_minus_reported", "sum", "Gap pm (sum excess - sum reported)")
    focus_agg = aggregate_country_window(
        d, focus_window, focus_outcome[0], focus_outcome[1], min_months_nonmissing=args.min_months_nonmissing
    )
    for x_var, x_label in alt_x_vars:
        if x_var not in focus_agg.columns:
            continue
        res = run_regression(
            focus_agg,
            y_col="y",
            x_col=x_var,
            controls=base_controls,
            weight_col=None,
            robust="HC1",
        )
        if not res:
            continue
        alt_x_rows.append(
            {
                "window": focus_window.name,
                "outcome": focus_outcome[0],
                "spec": "OLS_full_HC1",
                "x_var": x_var,
                "x_label": x_label,
                "controls": "+".join(base_controls),
                "n": res["n"],
                "coef": res["coef"],
                "se": res["se"],
                "p": res["p"],
            }
        )

    # Leave-one-continent-out for the main transparency result
    if "continent" in d.columns:
        continents = sorted(c for c in d["continent"].dropna().unique().tolist() if str(c).strip())
        for cont in continents:
            agg = aggregate_country_window(
                d[d["continent"] != cont].copy(),
                focus_window,
                focus_outcome[0],
                focus_outcome[1],
                min_months_nonmissing=args.min_months_nonmissing,
            )
            res = run_regression(
                agg,
                y_col="y",
                x_col="media_freedom",
                controls=base_controls,
                weight_col=None,
                robust="HC1",
            )
            if not res:
                continue
            loo_rows.append(
                {
                    "excluded_continent": cont,
                    "window": focus_window.name,
                    "outcome": focus_outcome[0],
                    "spec": "OLS_full_HC1",
                    "controls": "+".join(base_controls),
                    "n": res["n"],
                    "coef": res["coef"],
                    "se": res["se"],
                    "p": res["p"],
                }
            )

    # Data-quality / state-capacity robustness. Optional controls are included only if the
    # aggregated data carries them (the column set is the same for every window).
    wgi_candidates = (
        "wgi_gov_effectiveness",
        "wgi_rule_of_law",
        "wgi_control_corruption",
        "wgi_voice_accountability",
        "wgi_political_stability",
        "wgi_regulatory_quality",
    )
    death_reg = ["death_reg_cod_pct"] if "death_reg_cod_pct" in focus_agg.columns else []
    gov_eff = ["wgi_gov_effectiveness"] if "wgi_gov_effectiveness" in focus_agg.columns else []
    wgi_all = [c for c in wgi_candidates if c in focus_agg.columns]
    dq_specs: list[tuple[str, list[str]]] = [
        ("base_controls", base_controls),
        ("+death_reg", base_controls + death_reg),
        ("+gov_effectiveness", base_controls + gov_eff),
        ("+WGI_all", base_controls + wgi_all),
        ("+WGI_all+death_reg", base_controls + wgi_all + death_reg),
    ]

    def append_dq_rows(agg: pd.DataFrame, window_name: str, outcome: str, sample_name: str) -> None:
        """Fit every data-quality spec on ``agg`` and append the results to ``dq_rows``."""
        for name, controls in dq_specs:
            res = run_regression(
                agg,
                y_col="y",
                x_col="media_freedom",
                controls=controls,
                weight_col=None,
                robust="HC1",
            )
            if res:
                dq_rows.append(
                    {
                        "sample": sample_name,
                        "window": window_name,
                        "outcome": outcome,
                        "spec": name,
                        "controls": "+".join([c for c in controls if c]),
                        "n": res["n"],
                        "coef": res["coef"],
                        "se": res["se"],
                        "p": res["p"],
                    }
                )

    append_dq_rows(focus_agg, focus_window.name, focus_outcome[0], "all")
    if "death_reg_cod_pct" in focus_agg.columns:
        append_dq_rows(
            focus_agg[focus_agg["death_reg_cod_pct"].notna()].copy(),
            focus_window.name,
            focus_outcome[0],
            "death_reg_nonmissing",
        )

    # 2023 completeness check: re-estimate the 2020-2023 transparency result on nested
    # subsets of countries with at least `threshold` non-missing outcome months in 2023,
    # to see whether incomplete 2023 reporting drives the estimate.
    win_2023 = Window("2020-2023", "2020-02", "2023-12")
    agg_2023 = aggregate_country_window(
        d, win_2023, focus_outcome[0], focus_outcome[1], min_months_nonmissing=args.min_months_nonmissing
    )
    in_2023 = d[d["month_id"].between("2023-01", "2023-12")]
    months_2023 = in_2023.groupby("iso_code")[focus_outcome[0]].count().rename("months_nonmissing_2023")
    agg_2023 = agg_2023.merge(months_2023.reset_index(), on="iso_code", how="left")
    agg_2023["months_nonmissing_2023"] = agg_2023["months_nonmissing_2023"].fillna(0).astype(int)

    completeness_controls = base_controls + (
        ["death_reg_cod_pct"] if "death_reg_cod_pct" in agg_2023.columns else []
    )
    completeness_controls_label = "+".join(c for c in completeness_controls if c in agg_2023.columns)
    for threshold in [0, 6, 9, 12]:
        sub = agg_2023[agg_2023["months_nonmissing_2023"] >= threshold].copy()
        res = run_regression(
            sub,
            y_col="y",
            x_col="media_freedom",
            controls=completeness_controls,
            weight_col=None,
            robust="HC1",
        )
        if res:
            completeness_rows.append(
                {
                    "window": win_2023.name,
                    "outcome": focus_outcome[0],
                    "threshold_months_nonmissing_2023": threshold,
                    "n": res["n"],
                    "coef": res["coef"],
                    "se": res["se"],
                    "p": res["p"],
                    "controls": completeness_controls_label,
                }
            )

    # Also compute data-quality robustness for 2020-2023 (on the full sample and on death-reg-nonmissing subset)
    append_dq_rows(agg_2023, win_2023.name, focus_outcome[0], "all")
    if "death_reg_cod_pct" in agg_2023.columns:
        append_dq_rows(
            agg_2023[agg_2023["death_reg_cod_pct"].notna()].copy(),
            win_2023.name,
            focus_outcome[0],
            "death_reg_nonmissing",
        )

    # Write tables in a fixed order (this is also the order of the "Wrote ..." messages).
    results = pd.DataFrame(rows)
    table_frames = {
        "main_specs.csv": results,
        "sample_summary.csv": pd.DataFrame(sample_rows),
        "robustness_alt_x.csv": pd.DataFrame(alt_x_rows),
        "robustness_leave_one_continent.csv": pd.DataFrame(loo_rows),
        "robustness_data_quality.csv": pd.DataFrame(dq_rows),
        "robustness_2023_completeness.csv": pd.DataFrame(completeness_rows),
    }
    written_paths: list[Path] = []
    for filename, frame in table_frames.items():
        path = tables_dir / filename
        frame.to_csv(path, index=False, encoding="utf-8")
        written_paths.append(path)

    meta = {
        "take_home_message": TAKE_HOME_MESSAGE,
        "note": "Cross-sectional window aggregations; clip variants address negative revisions in excess mortality series.",
    }
    (tables_dir / "main_specs_meta.json").write_text(json.dumps(meta, ensure_ascii=False, indent=2), encoding="utf-8")

    coefplot(
        results,
        figures_dir / "coefplot_media_freedom_main_specs.png",
        title="Media freedom → transparency gap / excess mortality (robustness across windows & definitions)\n"
        + TAKE_HOME_MESSAGE,
    )

    for path in written_paths:
        print(f"Wrote {path}")


# Script entry point: run the full analysis only when executed directly, not on import.
if __name__ == "__main__":
    main()
